{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Importing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:36:40.176641Z",
     "iopub.status.busy": "2023-04-05T05:36:40.176175Z",
     "iopub.status.idle": "2023-04-05T05:36:40.183738Z",
     "shell.execute_reply": "2023-04-05T05:36:40.182060Z",
     "shell.execute_reply.started": "2023-04-05T05:36:40.176595Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torchtext\n",
    "from torchtext.legacy import data\n",
    "from torchtext.legacy import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:10.323147Z",
     "iopub.status.busy": "2023-04-05T05:31:10.322110Z",
     "iopub.status.idle": "2023-04-05T05:31:10.573771Z",
     "shell.execute_reply": "2023-04-05T05:31:10.572359Z",
     "shell.execute_reply.started": "2023-04-05T05:31:10.323099Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>clean_comment</th>\n",
       "      <th>category</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>family mormon have never tried explain them t...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>buddhism has very much lot compatible with chr...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>seriously don say thing first all they won get...</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>what you have learned yours and only yours wha...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>for your own benefit you may want read living ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       clean_comment  category\n",
       "0   family mormon have never tried explain them t...         1\n",
       "1  buddhism has very much lot compatible with chr...         1\n",
       "2  seriously don say thing first all they won get...        -1\n",
       "3  what you have learned yours and only yours wha...         0\n",
       "4  for your own benefit you may want read living ...         1"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reddit = pd.read_csv(r\"C:\\Users\\User\\Downloads\\Reddit_Data.csv\")\n",
    "reddit.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleaning the text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Converting to lower case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:10.576638Z",
     "iopub.status.busy": "2023-04-05T05:31:10.576283Z",
     "iopub.status.idle": "2023-04-05T05:31:10.635408Z",
     "shell.execute_reply": "2023-04-05T05:31:10.634051Z",
     "shell.execute_reply.started": "2023-04-05T05:31:10.576603Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>clean_comment</th>\n",
       "      <th>category</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>family mormon have never tried explain them t...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>buddhism has very much lot compatible with chr...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>seriously don say thing first all they won get...</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>what you have learned yours and only yours wha...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>for your own benefit you may want read living ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       clean_comment  category\n",
       "0   family mormon have never tried explain them t...         1\n",
       "1  buddhism has very much lot compatible with chr...         1\n",
       "2  seriously don say thing first all they won get...        -1\n",
       "3  what you have learned yours and only yours wha...         0\n",
       "4  for your own benefit you may want read living ...         1"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reddit['clean_comment'] = reddit['clean_comment'].str.lower()\n",
    "reddit.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Removing stop words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:10.639671Z",
     "iopub.status.busy": "2023-04-05T05:31:10.639112Z",
     "iopub.status.idle": "2023-04-05T05:31:11.534211Z",
     "shell.execute_reply": "2023-04-05T05:31:11.533096Z",
     "shell.execute_reply.started": "2023-04-05T05:31:10.639613Z"
    }
   },
   "outputs": [],
   "source": [
    "from nltk.corpus import stopwords\n",
    "from nltk.tokenize import word_tokenize\n",
    "from nltk.tokenize.toktok import ToktokTokenizer\n",
    "\n",
    "tokenizer=ToktokTokenizer() \n",
    "stop_words = set(stopwords.words('english'))\n",
    "def remove_stopwords(text):\n",
    "    tokens = tokenizer.tokenize(text)\n",
    "    tokens = [token.strip() for token in tokens]\n",
    "    filtered_tokens = [token for token in tokens if token not in stop_words]\n",
    "    filtered_text = ' '.join(filtered_tokens)    \n",
    "    return filtered_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:11.537330Z",
     "iopub.status.busy": "2023-04-05T05:31:11.536064Z",
     "iopub.status.idle": "2023-04-05T05:31:14.510351Z",
     "shell.execute_reply": "2023-04-05T05:31:14.508972Z",
     "shell.execute_reply.started": "2023-04-05T05:31:11.537271Z"
    }
   },
   "outputs": [],
   "source": [
    "reddit['clean_comment'] = reddit.clean_comment.apply(remove_stopwords)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:14.512138Z",
     "iopub.status.busy": "2023-04-05T05:31:14.511795Z",
     "iopub.status.idle": "2023-04-05T05:31:14.521865Z",
     "shell.execute_reply": "2023-04-05T05:31:14.520426Z",
     "shell.execute_reply.started": "2023-04-05T05:31:14.512104Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    family mormon never tried explain still stare ...\n",
       "1    buddhism much lot compatible christianity espe...\n",
       "2    seriously say thing first get complex explain ...\n",
       "3    learned want teach different focus goal wrappi...\n",
       "4    benefit may want read living buddha living chr...\n",
       "Name: clean_comment, dtype: object"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reddit['clean_comment'].head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-04T13:30:19.801814Z",
     "iopub.status.busy": "2023-04-04T13:30:19.800097Z",
     "iopub.status.idle": "2023-04-04T13:30:19.808761Z",
     "shell.execute_reply": "2023-04-04T13:30:19.806966Z",
     "shell.execute_reply.started": "2023-04-04T13:30:19.801671Z"
    }
   },
   "source": [
    "### Removing URLs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:14.524251Z",
     "iopub.status.busy": "2023-04-05T05:31:14.523755Z",
     "iopub.status.idle": "2023-04-05T05:31:14.868488Z",
     "shell.execute_reply": "2023-04-05T05:31:14.867029Z",
     "shell.execute_reply.started": "2023-04-05T05:31:14.524197Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    family mormon never tried explain still stare ...\n",
       "1    buddhism much lot compatible christianity espe...\n",
       "2    seriously say thing first get complex explain ...\n",
       "3    learned want teach different focus goal wrappi...\n",
       "4    benefit may want read living buddha living chr...\n",
       "Name: clean_comment, dtype: object"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "def cleaning_URLs(data):\n",
    "    return re.sub('((www.[^s]+)|(https?://[^s]+))',' ',data)\n",
    "reddit['clean_comment'] = reddit['clean_comment'].apply(lambda x: cleaning_URLs(x))\n",
    "reddit['clean_comment'].head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Removing Numbers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:14.873533Z",
     "iopub.status.busy": "2023-04-05T05:31:14.873106Z",
     "iopub.status.idle": "2023-04-05T05:31:15.057535Z",
     "shell.execute_reply": "2023-04-05T05:31:15.056063Z",
     "shell.execute_reply.started": "2023-04-05T05:31:14.873491Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    family mormon never tried explain still stare ...\n",
       "1    buddhism much lot compatible christianity espe...\n",
       "2    seriously say thing first get complex explain ...\n",
       "3    learned want teach different focus goal wrappi...\n",
       "4    benefit may want read living buddha living chr...\n",
       "Name: clean_comment, dtype: object"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def cleaning_numbers(data):\n",
    "    return re.sub('[0-9]+', '', data)\n",
    "reddit['clean_comment'] = reddit['clean_comment'].apply(lambda x: cleaning_numbers(x))\n",
    "reddit['clean_comment'].head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:15.059980Z",
     "iopub.status.busy": "2023-04-05T05:31:15.059251Z",
     "iopub.status.idle": "2023-04-05T05:31:15.066325Z",
     "shell.execute_reply": "2023-04-05T05:31:15.064757Z",
     "shell.execute_reply.started": "2023-04-05T05:31:15.059937Z"
    }
   },
   "outputs": [],
   "source": [
    "def remove_pattern(input_txt, pattern):\n",
    "    r = re.findall(pattern, input_txt)\n",
    "    for word in r:\n",
    "        input_txt = re.sub(word, \"\", input_txt)\n",
    "    return input_txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:15.069793Z",
     "iopub.status.busy": "2023-04-05T05:31:15.069405Z",
     "iopub.status.idle": "2023-04-05T05:31:17.372060Z",
     "shell.execute_reply": "2023-04-05T05:31:17.370618Z",
     "shell.execute_reply.started": "2023-04-05T05:31:15.069738Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>clean_comment</th>\n",
       "      <th>category</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>family mormon never tried explain still stare ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>buddhism much lot compatible christianity espe...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>seriously say thing first get complex explain ...</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>learned want teach different focus goal wrappi...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>benefit may want read living buddha living chr...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       clean_comment  category\n",
       "0  family mormon never tried explain still stare ...         1\n",
       "1  buddhism much lot compatible christianity espe...         1\n",
       "2  seriously say thing first get complex explain ...        -1\n",
       "3  learned want teach different focus goal wrappi...         0\n",
       "4  benefit may want read living buddha living chr...         1"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reddit['clean_comment'] = np.vectorize(remove_pattern)(reddit['clean_comment'], \"@[\\w]*\")\n",
    "reddit.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:17.374143Z",
     "iopub.status.busy": "2023-04-05T05:31:17.373794Z",
     "iopub.status.idle": "2023-04-05T05:31:17.619926Z",
     "shell.execute_reply": "2023-04-05T05:31:17.618385Z",
     "shell.execute_reply.started": "2023-04-05T05:31:17.374108Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\User\\AppData\\Local\\Temp\\ipykernel_14812\\4282037118.py:2: FutureWarning: The default value of regex will change from True to False in a future version.\n",
      "  reddit['clean_comment'] = reddit['clean_comment'].str.replace(r'[^a-zA-Z#]', ' ')\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>clean_comment</th>\n",
       "      <th>category</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>family mormon never tried explain still stare ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>buddhism much lot compatible christianity espe...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>seriously say thing first get complex explain ...</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>learned want teach different focus goal wrappi...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>benefit may want read living buddha living chr...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       clean_comment  category\n",
       "0  family mormon never tried explain still stare ...         1\n",
       "1  buddhism much lot compatible christianity espe...         1\n",
       "2  seriously say thing first get complex explain ...        -1\n",
       "3  learned want teach different focus goal wrappi...         0\n",
       "4  benefit may want read living buddha living chr...         1"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# remove special characters, numbers and punctuations\n",
    "reddit['clean_comment'] = reddit['clean_comment'].str.replace(r'[^a-zA-Z#]', ' ')\n",
    "reddit.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-04-05T05:31:17.622093Z",
     "iopub.status.busy": "2023-04-05T05:31:17.621750Z",
     "iopub.status.idle": "2023-04-05T05:31:17.818335Z",
     "shell.execute_reply": "2023-04-05T05:31:17.817048Z",
     "shell.execute_reply.started": "2023-04-05T05:31:17.622059Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>clean_comment</th>\n",
       "      <th>category</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>family mormon never tried explain still stare ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>buddhism much compatible christianity especial...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>seriously thing first complex explain normal p...</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>learned want teach different focus goal wrappi...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>benefit want read living buddha living christ ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       clean_comment  category\n",
       "0  family mormon never tried explain still stare ...         1\n",
       "1  buddhism much compatible christianity especial...         1\n",
       "2  seriously thing first complex explain normal p...        -1\n",
       "3  learned want teach different focus goal wrappi...         0\n",
       "4  benefit want read living buddha living christ ...         1"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# remove short words, since they are not going to be giving any meaning like u, ur etc.,\n",
    "\n",
    "reddit['clean_comment'] = reddit['clean_comment'].apply(lambda x: \" \".join([w for w in x.split() if len(w)>3]))\n",
    "reddit.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "RANDOM_SEED = 123\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "VOCABULARY_SIZE = 20000\n",
    "LEARNING_RATE = 0.005\n",
    "BATCH_SIZE = 128\n",
    "NUM_EPOCHS = 15\n",
    "DEVICE = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "EMBEDDING_DIM = 128\n",
    "HIDDEN_DIM = 256\n",
    "NUM_CLASSES = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Defining the feature processing\n",
    "\n",
    "TEXT = torchtext.legacy.data.Field(\n",
    "    tokenize='spacy', # default splits on whitespace\n",
    "    tokenizer_language='en_core_web_sm'\n",
    ")\n",
    "\n",
    "### Defining the label processing\n",
    "\n",
    "LABEL = torchtext.legacy.data.LabelField(dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "fields = [('TEXT_COLUMN_NAME', TEXT), ('LABEL_COLUMN_NAME', LABEL)]\n",
    "\n",
    "dataset = torchtext.legacy.data.TabularDataset(\n",
    "    path='Reddit_Data.csv', format='csv',\n",
    "    skip_header=True, fields=fields)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num Train: 29799\n",
      "Num Test: 7450\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "train_data, test_data = dataset.split(\n",
    "    split_ratio=[0.8, 0.2],\n",
    "    random_state=random.seed(RANDOM_SEED))\n",
    "\n",
    "print(f'Num Train: {len(train_data)}')\n",
    "print(f'Num Test: {len(test_data)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num Train: 25329\n",
      "Num Validation: 4470\n"
     ]
    }
   ],
   "source": [
    "train_data, valid_data = train_data.split(\n",
    "    split_ratio=[0.85, 0.15],\n",
    "    random_state=random.seed(RANDOM_SEED))\n",
    "\n",
    "print(f'Num Train: {len(train_data)}')\n",
    "print(f'Num Validation: {len(valid_data)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'TEXT_COLUMN_NAME': [' ', 'there', 'any', 'explanation', 'for', 'why', 'all', 'the', 'giants', 'are', 'dead', 'jotunheim', 'seems', 'there', 'was', 'battle', 'there', 'blood', 'the', 'mountains', 'maybe', 'there', 'was', 'fight', 'about', 'what', 'with', 'loki', 'and', 'that', 'how', 'they', 'died', 'and', 'laufey', 'got', 'mortally', 'wounded'], 'LABEL_COLUMN_NAME': '-1'}\n"
     ]
    }
   ],
   "source": [
    "print(vars(train_data.examples[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocabulary size: 20002\n",
      "Number of classes: 3\n"
     ]
    }
   ],
   "source": [
    "TEXT.build_vocab(train_data, max_size=VOCABULARY_SIZE)\n",
    "LABEL.build_vocab(train_data)\n",
    "\n",
    "print(f'Vocabulary size: {len(TEXT.vocab)}')\n",
    "print(f'Number of classes: {len(LABEL.vocab)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(None, {'1': 0, '0': 1, '-1': 2})\n"
     ]
    }
   ],
   "source": [
    "print(LABEL.vocab.stoi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, valid_loader, test_loader = \\\n",
    "    torchtext.legacy.data.BucketIterator.splits(\n",
    "        (train_data, valid_data, test_data),\n",
    "         batch_size=BATCH_SIZE,\n",
    "         sort_within_batch=False,\n",
    "         sort_key=lambda x: len(x.TEXT_COLUMN_NAME),\n",
    "         device=DEVICE\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train\n",
      "Text matrix size: torch.Size([594, 128])\n",
      "Target vector size: torch.Size([128])\n",
      "\n",
      "Valid:\n",
      "Text matrix size: torch.Size([2, 128])\n",
      "Target vector size: torch.Size([128])\n",
      "\n",
      "Test:\n",
      "Text matrix size: torch.Size([1, 128])\n",
      "Target vector size: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "print('Train')\n",
    "for batch in train_loader:\n",
    "    print(f'Text matrix size: {batch.TEXT_COLUMN_NAME.size()}')\n",
    "    print(f'Target vector size: {batch.LABEL_COLUMN_NAME.size()}')\n",
    "    break\n",
    "    \n",
    "print('\\nValid:')\n",
    "for batch in valid_loader:\n",
    "    print(f'Text matrix size: {batch.TEXT_COLUMN_NAME.size()}')\n",
    "    print(f'Target vector size: {batch.LABEL_COLUMN_NAME.size()}')\n",
    "    break\n",
    "    \n",
    "print('\\nTest:')\n",
    "for batch in test_loader:\n",
    "    print(f'Text matrix size: {batch.TEXT_COLUMN_NAME.size()}')\n",
    "    print(f'Target vector size: {batch.LABEL_COLUMN_NAME.size()}')\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNN(torch.nn.Module):\n",
    "    \n",
    "    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):\n",
    "        super().__init__()\n",
    "\n",
    "        self.embedding = torch.nn.Embedding(input_dim, embedding_dim)\n",
    "        #self.rnn = torch.nn.RNN(embedding_dim,\n",
    "        #                        hidden_dim,\n",
    "        #                        nonlinearity='relu')\n",
    "        self.rnn = torch.nn.LSTM(embedding_dim,\n",
    "                                 hidden_dim)        \n",
    "        \n",
    "        self.fc = torch.nn.Linear(hidden_dim, output_dim)\n",
    "        \n",
    "\n",
    "    def forward(self, text):\n",
    "        # text dim: [sentence length, batch size]\n",
    "        \n",
    "        embedded = self.embedding(text)\n",
    "        # embedded dim: [sentence length, batch size, embedding dim]\n",
    "        \n",
    "        output, (hidden, cell) = self.rnn(embedded)\n",
    "        # output dim: [sentence length, batch size, hidden dim]\n",
    "        # hidden dim: [1, batch size, hidden dim]\n",
    "\n",
    "        hidden.squeeze_(0)\n",
    "        # hidden dim: [batch size, hidden dim]\n",
    "        \n",
    "        output = self.fc(hidden)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(RANDOM_SEED)\n",
    "model = RNN(input_dim=len(TEXT.vocab),\n",
    "            embedding_dim=EMBEDDING_DIM,\n",
    "            hidden_dim=HIDDEN_DIM,\n",
    "            output_dim=NUM_CLASSES # could use 1 for binary classification\n",
    ")\n",
    "\n",
    "model = model.to(DEVICE)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.005)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(model, data_loader, device):\n",
    "\n",
    "    with torch.no_grad():\n",
    "\n",
    "        correct_pred, num_examples = 0, 0\n",
    "\n",
    "        for i, (features, targets) in enumerate(data_loader):\n",
    "\n",
    "            features = features.to(device)\n",
    "            targets = targets.float().to(device)\n",
    "\n",
    "            logits = model(features)\n",
    "            _, predicted_labels = torch.max(logits, 1)\n",
    "\n",
    "            num_examples += targets.size(0)\n",
    "            correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/015 | Batch 000/198 | Loss: 1.0990\n",
      "Epoch: 001/015 | Batch 050/198 | Loss: 1.0466\n",
      "Epoch: 001/015 | Batch 100/198 | Loss: 1.0545\n",
      "Epoch: 001/015 | Batch 150/198 | Loss: 1.0481\n",
      "training accuracy: 42.50%\n",
      "valid accuracy: 43.11%\n",
      "Time elapsed: 27.10 min\n",
      "Epoch: 002/015 | Batch 000/198 | Loss: 1.0672\n",
      "Epoch: 002/015 | Batch 050/198 | Loss: 1.0135\n",
      "Epoch: 002/015 | Batch 100/198 | Loss: 0.9781\n",
      "Epoch: 002/015 | Batch 150/198 | Loss: 0.6725\n",
      "training accuracy: 78.54%\n",
      "valid accuracy: 61.25%\n",
      "Time elapsed: 37.52 min\n",
      "Epoch: 003/015 | Batch 000/198 | Loss: 0.6198\n",
      "Epoch: 003/015 | Batch 050/198 | Loss: 0.3515\n",
      "Epoch: 003/015 | Batch 100/198 | Loss: 0.5450\n",
      "Epoch: 003/015 | Batch 150/198 | Loss: 0.3892\n",
      "training accuracy: 90.70%\n",
      "valid accuracy: 78.48%\n",
      "Time elapsed: 43.53 min\n",
      "Epoch: 004/015 | Batch 000/198 | Loss: 0.2666\n",
      "Epoch: 004/015 | Batch 050/198 | Loss: 0.2228\n",
      "Epoch: 004/015 | Batch 100/198 | Loss: 0.2836\n",
      "Epoch: 004/015 | Batch 150/198 | Loss: 0.2245\n",
      "training accuracy: 94.22%\n",
      "valid accuracy: 82.39%\n",
      "Time elapsed: 49.47 min\n",
      "Epoch: 005/015 | Batch 000/198 | Loss: 0.1771\n",
      "Epoch: 005/015 | Batch 050/198 | Loss: 0.2488\n",
      "Epoch: 005/015 | Batch 100/198 | Loss: 0.1787\n",
      "Epoch: 005/015 | Batch 150/198 | Loss: 0.2053\n",
      "training accuracy: 95.78%\n",
      "valid accuracy: 81.74%\n",
      "Time elapsed: 55.41 min\n",
      "Epoch: 006/015 | Batch 000/198 | Loss: 0.1745\n",
      "Epoch: 006/015 | Batch 050/198 | Loss: 0.0574\n",
      "Epoch: 006/015 | Batch 100/198 | Loss: 0.0953\n",
      "Epoch: 006/015 | Batch 150/198 | Loss: 0.2034\n",
      "training accuracy: 96.58%\n",
      "valid accuracy: 84.16%\n",
      "Time elapsed: 61.18 min\n",
      "Epoch: 007/015 | Batch 000/198 | Loss: 0.0702\n",
      "Epoch: 007/015 | Batch 050/198 | Loss: 0.1258\n",
      "Epoch: 007/015 | Batch 100/198 | Loss: 0.1193\n",
      "Epoch: 007/015 | Batch 150/198 | Loss: 0.1081\n",
      "training accuracy: 97.68%\n",
      "valid accuracy: 84.09%\n",
      "Time elapsed: 67.20 min\n",
      "Epoch: 008/015 | Batch 000/198 | Loss: 0.0626\n",
      "Epoch: 008/015 | Batch 050/198 | Loss: 0.1192\n",
      "Epoch: 008/015 | Batch 100/198 | Loss: 0.1864\n",
      "Epoch: 008/015 | Batch 150/198 | Loss: 0.0691\n",
      "training accuracy: 98.39%\n",
      "valid accuracy: 84.09%\n",
      "Time elapsed: 73.21 min\n",
      "Epoch: 009/015 | Batch 000/198 | Loss: 0.0684\n",
      "Epoch: 009/015 | Batch 050/198 | Loss: 0.0537\n",
      "Epoch: 009/015 | Batch 100/198 | Loss: 0.0214\n",
      "Epoch: 009/015 | Batch 150/198 | Loss: 0.0856\n",
      "training accuracy: 98.89%\n",
      "valid accuracy: 85.73%\n",
      "Time elapsed: 79.22 min\n",
      "Epoch: 010/015 | Batch 000/198 | Loss: 0.0633\n",
      "Epoch: 010/015 | Batch 050/198 | Loss: 0.0127\n",
      "Epoch: 010/015 | Batch 100/198 | Loss: 0.1077\n",
      "Epoch: 010/015 | Batch 150/198 | Loss: 0.0209\n",
      "training accuracy: 98.97%\n",
      "valid accuracy: 85.53%\n",
      "Time elapsed: 85.18 min\n",
      "Epoch: 011/015 | Batch 000/198 | Loss: 0.0315\n",
      "Epoch: 011/015 | Batch 050/198 | Loss: 0.0249\n",
      "Epoch: 011/015 | Batch 100/198 | Loss: 0.0982\n",
      "Epoch: 011/015 | Batch 150/198 | Loss: 0.0905\n",
      "training accuracy: 99.04%\n",
      "valid accuracy: 86.40%\n",
      "Time elapsed: 91.07 min\n",
      "Epoch: 012/015 | Batch 000/198 | Loss: 0.0410\n",
      "Epoch: 012/015 | Batch 050/198 | Loss: 0.0482\n",
      "Epoch: 012/015 | Batch 100/198 | Loss: 0.0096\n",
      "Epoch: 012/015 | Batch 150/198 | Loss: 0.0185\n",
      "training accuracy: 99.23%\n",
      "valid accuracy: 85.03%\n",
      "Time elapsed: 96.83 min\n",
      "Epoch: 013/015 | Batch 000/198 | Loss: 0.0408\n",
      "Epoch: 013/015 | Batch 050/198 | Loss: 0.0431\n",
      "Epoch: 013/015 | Batch 100/198 | Loss: 0.0198\n",
      "Epoch: 013/015 | Batch 150/198 | Loss: 0.0824\n",
      "training accuracy: 99.24%\n",
      "valid accuracy: 86.31%\n",
      "Time elapsed: 102.52 min\n",
      "Epoch: 014/015 | Batch 000/198 | Loss: 0.0247\n",
      "Epoch: 014/015 | Batch 050/198 | Loss: 0.0736\n",
      "Epoch: 014/015 | Batch 100/198 | Loss: 0.0315\n",
      "Epoch: 014/015 | Batch 150/198 | Loss: 0.0462\n",
      "training accuracy: 99.34%\n",
      "valid accuracy: 85.64%\n",
      "Time elapsed: 108.13 min\n",
      "Epoch: 015/015 | Batch 000/198 | Loss: 0.0340\n",
      "Epoch: 015/015 | Batch 050/198 | Loss: 0.0880\n",
      "Epoch: 015/015 | Batch 100/198 | Loss: 0.0032\n",
      "Epoch: 015/015 | Batch 150/198 | Loss: 0.0037\n",
      "training accuracy: 99.40%\n",
      "valid accuracy: 85.55%\n",
      "Time elapsed: 113.88 min\n",
      "Total Training Time: 113.88 min\n",
      "Test accuracy: 85.22%\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import torch.nn.functional as F\n",
    "start_time = time.time()\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    for batch_idx, batch_data in enumerate(train_loader):\n",
    "        \n",
    "        text = batch_data.TEXT_COLUMN_NAME.to(DEVICE)\n",
    "        labels = batch_data.LABEL_COLUMN_NAME.to(DEVICE)\n",
    "\n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits = model(text)\n",
    "        loss = F.cross_entropy(logits, labels)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
    "                   f'Batch {batch_idx:03d}/{len(train_loader):03d} | '\n",
    "                   f'Loss: {loss:.4f}')\n",
    "\n",
    "    with torch.set_grad_enabled(False):\n",
    "        print(f'training accuracy: '\n",
    "              f'{compute_accuracy(model, train_loader, DEVICE):.2f}%'\n",
    "              f'\\nvalid accuracy: '\n",
    "              f'{compute_accuracy(model, valid_loader, DEVICE):.2f}%')\n",
    "        \n",
    "    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')\n",
    "    \n",
    "print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')\n",
    "print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probability positive:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.992868959903717"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import spacy\n",
    "\n",
    "\n",
    "nlp = spacy.blank(\"en\")\n",
    "\n",
    "def predict_sentiment(model, sentence):\n",
    "\n",
    "    model.eval()\n",
    "    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
    "    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
    "    length = [len(indexed)]\n",
    "    tensor = torch.LongTensor(indexed).to(DEVICE)\n",
    "    tensor = tensor.unsqueeze(1)\n",
    "    length_tensor = torch.LongTensor(length)\n",
    "    prediction = torch.nn.functional.softmax(model(tensor), dim=1)\n",
    "    return prediction[0][0].item()\n",
    "\n",
    "print('Probability positive:')\n",
    "predict_sentiment(model, \"This is such an awesome movie, I really love it!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probability negative:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.993397630751133"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print('Probability negative:')\n",
    "1-predict_sentiment(model, \"I really hate this movie. It is really bad and sucks!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
